{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c79b790b",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "- This notebook uses data from [one million reddit confessions](https://www.kaggle.com/datasets/pavellexyr/one-million-reddit-confessions) and [prosocial-dialog]() to synthesize samples to help train safety models.\n",
    "- A [classifier](https://huggingface.co/shahules786/prosocial-classifier) trained on prosocial dialog dataset is used for pseudo labeling.\n",
    "- More information on dataset can be found [here](https://huggingface.co/datasets/shahules786/prosocial-confessions)\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/data/datasets/prosocial_confessions/prosocial-confessions.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "316d3090",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ad1e33",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q kaggle\n",
    "! pip install -q sentence_transformers"
   ]
  },
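  {
   "cell_type": "markdown",
   "id": "b3a91c0e",
   "metadata": {},
   "source": [
    "The `kaggle` client looks for credentials at `~/.kaggle/kaggle.json` (downloadable from your [account page](https://www.kaggle.com/me/account)). A minimal setup sketch, assuming `kaggle.json` has been placed in the working directory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3a91c0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the API token where the kaggle client expects it and restrict its permissions.\n",
    "! mkdir -p ~/.kaggle && cp kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json"
   ]
  },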
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "435c534a",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'kaggle'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Input \u001b[0;32mIn [1]\u001b[0m, in \u001b[0;36m<cell line: 18>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m     17\u001b[0m np\u001b[38;5;241m.\u001b[39mset_printoptions(precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m)\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mkaggle\u001b[39;00m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'kaggle'"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset, Dataset\n",
    "import json\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "SBERT_MODEL = \"all-MiniLM-L6-v2\"\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.spatial as sp\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "import faiss\n",
    "from collections import ChainMap\n",
    "from scipy.special import softmax\n",
    "import torch\n",
    "\n",
    "np.set_printoptions(precision=3)\n",
    "import kaggle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c431d120",
   "metadata": {},
   "outputs": [],
   "source": [
    "THRESHOLD = 0.72\n",
    "B_SIZE = 1000\n",
    "MAXLEN = 196"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ff3ec8db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download data, you can get your kaggle.json file from your account page https://www.kaggle.com/me/account\n",
    "kaggle.api.dataset_download_files(\"pavellexyr/one-million-reddit-confessions\", unzip=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56046eba",
   "metadata": {},
   "source": [
    "## Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a9e2aaf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_vectorizer(model=SBERT_MODEL):\n",
    "    return SentenceTransformer(model)\n",
    "\n",
    "\n",
    "def vectorize_text(model, texts):\n",
    "    return model.encode(texts, show_progress_bar=True)\n",
    "\n",
    "\n",
    "def build_index(vector):\n",
    "    d = vector.shape[1]\n",
    "    index = faiss.IndexFlatIP(d)\n",
    "    faiss.normalize_L2(rot_vector)\n",
    "    index.add(rot_vector)\n",
    "    return index\n",
    "\n",
    "\n",
    "def get_conversations(rot_vector, posts_vector):\n",
    "    index = build_index(rot_vector)\n",
    "    num_result = 5\n",
    "    faiss.normalize_L2(posts_vector)\n",
    "    matching_dict = defaultdict(list)\n",
    "    for idx in tqdm(range(0, posts_vector.shape[0], B_SIZE)):\n",
    "        I, D = index.search(posts_vector[idx : idx + B_SIZE], num_result)\n",
    "        sim_indices = np.argwhere(I >= THRESHOLD)\n",
    "        # return sim_indices\n",
    "        for k, j in sim_indices:\n",
    "            matching_dict[k + idx].append({D[k][j]: I[k][j]})\n",
    "\n",
    "    return matching_dict\n",
    "\n",
    "\n",
    "def get_top_n(data, n=2):\n",
    "    data = dict(ChainMap(*data))\n",
    "    data = {k: v for k, v in sorted(data.items(), key=lambda data: data[1], reverse=True)}\n",
    "    return list(data.keys())[:n]\n",
    "\n",
    "\n",
    "def convert_to_json(match_dict, all_rots, reddit_df):\n",
    "    all_data = []\n",
    "    posts = reddit_df[\"title\"].values\n",
    "    permalinks = reddit_df[\"permalink\"].values\n",
    "    for post_id, rot_list in match_dict.items():\n",
    "        rot_ids = get_top_n(rot_list)\n",
    "        ps_rots = [all_rots[idx] for idx in rot_ids]\n",
    "        all_data.append(\n",
    "            {\"context\": posts[post_id], \"rots\": ps_rots, \"source\": permalinks[post_id], \"episone_done\": True}\n",
    "        )\n",
    "\n",
    "    return all_data\n",
    "\n",
    "\n",
    "def load_confessions(path):\n",
    "    conf_df = pd.read_csv(path)\n",
    "    conf_df_filtered = conf_df.query(\"score>3.0\")\n",
    "    conf_df_filtered[\"title_len\"] = conf_df_filtered[\"title\"].map(lambda x: len(x.split()))\n",
    "    conf_df_filtered = conf_df_filtered.query(\"title_len>5\")\n",
    "    return conf_df_filtered.reset_index(drop=True)\n",
    "\n",
    "\n",
    "def get_rots(dataset):\n",
    "    rots = [item[\"rots\"] for item in dataset]\n",
    "    rots = set([x for item in rots for x in item])\n",
    "    rots = list(rots)\n",
    "    return rots"
   ]
  },
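  {
   "cell_type": "markdown",
   "id": "e7f20c1a",
   "metadata": {},
   "source": [
    "A quick sanity check of the matching pipeline on toy vectors (values are arbitrary; each \"post\" should match only the ROT it points towards):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f20c1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two orthogonal \"ROTs\" and two \"posts\", each close to one of them.\n",
    "toy_rots = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=\"float32\")\n",
    "toy_posts = np.array([[0.9, 0.1], [0.1, 0.9]], dtype=\"float32\")\n",
    "get_conversations(toy_rots, toy_posts)  # post 0 -> ROT 0, post 1 -> ROT 1"
   ]
  },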
  {
   "cell_type": "markdown",
   "id": "1330d1ba",
   "metadata": {},
   "source": [
    "## read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fbc63e37",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_163420/1766470847.py:52: DtypeWarning: Columns (8) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  conf_df = pd.read_csv(path)\n",
      "/tmp/ipykernel_163420/1766470847.py:54: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  conf_df_filtered[\"title_len\"] = conf_df_filtered['title'].map(lambda x : len(x.split()))\n"
     ]
    }
   ],
   "source": [
    "conf_df_filtered = load_confessions(\"/home/shahul/Data/one-million-reddit-confessions.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f089487b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration allenai--prosocial-dialog-ebbad39ca08b6d44\n",
      "Found cached dataset json (/home/shahul/.cache/huggingface/datasets/allenai___json/allenai--prosocial-dialog-ebbad39ca08b6d44/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
     ]
    }
   ],
   "source": [
    "dataset = load_dataset(\"allenai/prosocial-dialog\", split=\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef3b8632",
   "metadata": {},
   "source": [
    "##  Embed and match ROTs with posts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d0a3d035",
   "metadata": {},
   "outputs": [],
   "source": [
    "rots = get_rots(dataset)"
   ]
  },
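  {
   "cell_type": "markdown",
   "id": "f4c8d2df",
   "metadata": {},
   "source": [
    "A quick look at the extracted ROTs (the count and samples will vary with the dataset version):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c8d2e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(rots), rots[:3]"
   ]
  },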
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3a169a67",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = load_vectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3580d6db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbf26de30d194f8caee2bffe04dc9c56",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/3630 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "rot_vector = vectorize_text(model, rots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "828a26ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b87895f44b94b9dbe921258499de069",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/32 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "posts = conf_df_filtered[\"title\"].values.tolist()\n",
    "posts_vector = vectorize_text(model, posts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7bb4d884",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████| 1/1 [00:03<00:00,  3.24s/it]\n"
     ]
    }
   ],
   "source": [
    "match_dict = get_conversations(rot_vector, posts_vector)"
   ]
  },
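  {
   "cell_type": "markdown",
   "id": "a8d13b72",
   "metadata": {},
   "source": [
    "How many posts found at least one ROT above the similarity threshold:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d13b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"{len(match_dict)} of {len(posts)} posts matched at least one ROT\")"
   ]
  },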
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c4f6ad39",
   "metadata": {},
   "outputs": [],
   "source": [
    "json_data = convert_to_json(match_dict, rots, conf_df_filtered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "99dc7c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "confession_dataset = Dataset.from_list(json_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b6face5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['context', 'rots', 'permalink', 'episone_done'],\n",
       "    num_rows: 44\n",
       "})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confession_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26477b70",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## Pseudo label"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ae2af2c",
   "metadata": {},
   "source": [
    "### Dataset loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ceb5489c",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_to_id = {\n",
    "    \"__casual__\": 0,\n",
    "    \"__needs_caution__\": 1,\n",
    "    \"__needs_intervention__\": 2,\n",
    "    \"__probably_needs_caution__\": 3,\n",
    "    \"__possibly_needs_caution__\": 4,\n",
    "}\n",
    "id_to_label = {v: k for k, v in label_to_id.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2894b020",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "\n",
    "class ProSocialDataset(Dataset):\n",
    "    def __init__(self, dataset, tokenizer):\n",
    "        super().__init__()\n",
    "        self.tokenizer = tokenizer\n",
    "        self.sep_token = self.tokenizer.sep_token\n",
    "        self.dataset = dataset\n",
    "        self.label2id = label_to_id\n",
    "        self.id2label = {v: k for k, v in label_to_id.items()}\n",
    "        self.max_len = MAXLEN\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        encoding = {}\n",
    "\n",
    "        context = self.dataset[idx][\"context\"]\n",
    "        rots = self.dataset[idx][\"rots\"]\n",
    "        input_tokens = self.tokenizer.encode(self.dataset[idx][\"context\"], add_special_tokens=False)\n",
    "        max_len = max(0, self.max_len - (len(input_tokens) + 4))\n",
    "\n",
    "        rots = self.tokenizer.encode(\n",
    "            self.tokenizer.sep_token.join(rots),\n",
    "            add_special_tokens=False,\n",
    "            max_length=max_len,\n",
    "        )\n",
    "\n",
    "        input_ids = [0] + input_tokens + [self.tokenizer.sep_token_id] + rots + [self.tokenizer.eos_token_id]\n",
    "        input_ids = input_ids + [self.tokenizer.pad_token_id] * max(0, (self.max_len - len(input_ids)))\n",
    "        mask = [1] * len(input_ids) + [self.tokenizer.pad_token_id] * (self.max_len - len(input_ids))\n",
    "\n",
    "        encoding[\"input_ids\"] = torch.tensor(input_ids)\n",
    "        encoding[\"attention_mask\"] = torch.tensor(mask)\n",
    "\n",
    "        return encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65fe9b17",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "31588090",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"shahules786/prosocial-classifier\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d1c061fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(MODEL)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"roberta-base\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a05421c",
   "metadata": {},
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fd1b24e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(pred_dataloader):\n",
    "    probs, labels = [], []\n",
    "    for data in tqdm(pred_dataloader):\n",
    "        pred = softmax(\n",
    "            model(data[\"input_ids\"], data[\"attention_mask\"]).logits.detach().numpy().astype(\"float16\"),\n",
    "            axis=1,\n",
    "        )\n",
    "        probs.append(pred.max(axis=1))\n",
    "        labels.append([model.config.id2label.get(pred_id) for pred_id in pred.argmax(axis=1)])\n",
    "\n",
    "    return np.concatenate(probs), np.concatenate(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "26699750",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_dataset = ProSocialDataset(confession_dataset, tokenizer)"
   ]
  },
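  {
   "cell_type": "markdown",
   "id": "c2e57a10",
   "metadata": {},
   "source": [
    "A quick sanity check on one encoded sample: the input should be padded to `MAXLEN`, and decoding the unmasked tokens should reproduce the post followed by its ROTs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e57a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = pred_dataset[0]\n",
    "n_real = int(sample[\"attention_mask\"].sum())\n",
    "print(sample[\"input_ids\"].shape, n_real)\n",
    "print(tokenizer.decode(sample[\"input_ids\"][:n_real]))"
   ]
  },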
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "f96786ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_dataset = DataLoader(pred_dataset, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3834fae2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                  | 0/11 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
      "100%|█████████████████████████| 11/11 [00:09<00:00,  1.22it/s]\n"
     ]
    }
   ],
   "source": [
    "prob, label = predict(pred_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09357e26",
   "metadata": {},
   "source": [
    "## Add labels to dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "73213814",
   "metadata": {},
   "outputs": [],
   "source": [
    "confession_dataset = confession_dataset.add_column(\"confidence\", np.array(prob, dtype=\"float32\"))\n",
    "confession_dataset = confession_dataset.add_column(\"safety_label\", label)\n",
    "confession_dataset = confession_dataset.add_column(\"response\", [None] * len(confession_dataset))"
   ]
  },
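  {
   "cell_type": "markdown",
   "id": "d9f04c21",
   "metadata": {},
   "source": [
    "Before training on pseudo labels, it can help to drop low-confidence predictions. A minimal sketch; the 0.9 cutoff is an arbitrary choice, not part of the published dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9f04c22",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_conf = confession_dataset.filter(lambda x: x[\"confidence\"] >= 0.9)\n",
    "print(f\"kept {len(high_conf)} of {len(confession_dataset)} samples\")"
   ]
  },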
  {
   "cell_type": "markdown",
   "id": "0b0d96d9",
   "metadata": {},
   "source": [
    "## Upload to Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "262b7ba0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Pushing split train to the Hub.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ddbc321760b4830baf1086e0ecc7fd7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c478e9f802e940fe822d7c8a2acf4ca0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/15 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69560e6faccd463e89e9d25c08af59dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4bd7b783660a4abbbf9f1af5c58ccbea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "confession_dataset.push_to_hub(\"shahules786/prosocial-confessions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936793c4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "OpenAssistant",
   "language": "python",
   "name": "openassistant"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
